{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f49c5c4-1170-42a7-9d6a-b90acd00c3c3",
   "metadata": {},
   "source": [
    "# RAFT IVF Flat Example Notebook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bcfe810-f120-422c-b2bb-72cc43d0c4ca",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook demonstrates how to run approximate nearest neighbor search using the RAFT IVF-Flat algorithm.\n",
    "It builds and searches an index on one of the million-scale ann-benchmarks datasets, saves the index to disk and loads it back, and explores the parameters that matter most for tuning search performance and accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "fe73ada7-7b7f-4005-9440-85428194311b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cupy as cp\n",
    "import numpy as np\n",
    "from pylibraft.common import DeviceResources\n",
    "from pylibraft.neighbors import ivf_flat\n",
    "import matplotlib.pyplot as plt\n",
    "import tempfile\n",
    "from utils import BenchmarkTimer, calc_recall, load_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da9e8615-ea9f-4735-b70f-15ccab36c0d9",
   "metadata": {},
   "source": [
    "For best performance it is recommended to use an RMM pooling allocator, to minimize the overheads of repeated CUDA allocations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5350e4d9-0993-406a-80af-29538b5677c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import rmm\n",
    "from rmm.allocators.cupy import rmm_cupy_allocator\n",
    "mr = rmm.mr.PoolMemoryResource(\n",
    "     rmm.mr.CudaMemoryResource(),\n",
    "     initial_pool_size=2**30\n",
    ")\n",
    "rmm.mr.set_current_device_resource(mr)\n",
    "cp.cuda.set_allocator(rmm_cupy_allocator)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0d935f2-ba24-44fc-bdfe-a769b7fcd8e6",
   "metadata": {},
   "source": [
    "The following GPU is used for this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a5daa4b4-96de-4e74-bfd6-505b13595f62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Sep 21 02:30:53 2023       \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |\n",
      "|-----------------------------------------+----------------------+----------------------+\n",
      "| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                                         |                      |               MIG M. |\n",
      "|=========================================+======================+======================|\n",
      "|   0  NVIDIA H100 PCIe               On  | 00000000:41:00.0 Off |                    0 |\n",
      "| N/A   35C    P0              69W / 350W |   1487MiB / 81559MiB |      0%      Default |\n",
      "|                                         |                      |             Disabled |\n",
      "+-----------------------------------------+----------------------+----------------------+\n",
      "                                                                                         \n",
      "+---------------------------------------------------------------------------------------+\n",
      "| Processes:                                                                            |\n",
      "|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |\n",
      "|        ID   ID                                                             Usage      |\n",
      "|=======================================================================================|\n",
      "|    0   N/A  N/A      3940      C   /opt/conda/envs/rapids/bin/python          1474MiB |\n",
      "+---------------------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "# Report the GPU in use\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88a654cc-6389-4526-a3e6-826de5606a09",
   "metadata": {},
   "source": [
    "## Load dataset\n",
    "\n",
    "The ANN benchmarks website provides the datasets in HDF5 format.\n",
    "\n",
    "The list of prepared datasets can be found at https://github.com/erikbern/ann-benchmarks/#data-sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5f529ad6-b0bd-495c-bf7c-43f10fb6aa14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The index and data will be saved in /tmp/raft_example\n"
     ]
    }
   ],
   "source": [
    "WORK_FOLDER = os.path.join(tempfile.gettempdir(), \"raft_example\")\n",
    "f = load_dataset(\"http://ann-benchmarks.com/sift-128-euclidean.hdf5\", work_folder=WORK_FOLDER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3d68a7db-bcf4-449c-96c3-1e8ab146c84d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded dataset of size (1000000, 128),  0.5 GiB; metric: 'euclidean'.\n",
      "Number of test queries: 10000\n"
     ]
    }
   ],
   "source": [
    "metric = f.attrs['distance']\n",
    "\n",
    "dataset = cp.array(f['train'])\n",
    "queries = cp.array(f['test'])\n",
    "gt_neighbors = cp.array(f['neighbors'])\n",
    "gt_distances = cp.array(f['distances'])\n",
    "\n",
    "itemsize = dataset.dtype.itemsize \n",
    "\n",
    "print(f\"Loaded dataset of size {dataset.shape}, {dataset.size*itemsize/(1<<30):4.1f} GiB; metric: '{metric}'.\")\n",
    "print(f\"Number of test queries: {queries.shape[0]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f463c50-d1d3-49be-bcfe-952602efa603",
   "metadata": {},
   "source": [
    "## Build index\n",
    "We set [IndexParams](https://docs.rapids.ai/api/raft/nightly/pylibraft_api/neighbors/#pylibraft.neighbors.ivf_flat.IndexParams) and build the index. The index parameters will be discussed in more detail in later sections of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "737f8841-93f9-4c8e-b2e1-787d4474ef94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 120 ms, sys: 5.33 ms, total: 125 ms\n",
      "Wall time: 124 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "build_params = ivf_flat.IndexParams(\n",
    "        n_lists=1024,\n",
    "        metric=\"euclidean\",\n",
    "        kmeans_trainset_fraction=0.1,\n",
    "        kmeans_n_iters=20,\n",
    "        add_data_on_build=True\n",
    "    )\n",
    "\n",
    "index = ivf_flat.build(build_params, dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a16a0cf6-3b05-4afd-9bb8-54431e0d7439",
   "metadata": {},
   "source": [
    "The index is built. We can print some basic information about it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1aec7024-6e5d-4d2c-82e6-7b5734aec958",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(type=IVF-FLAT, metric=euclidean, size=1000000, dim=128, n_lists=1024, adaptive_centers=False)\n"
     ]
    }
   ],
   "source": [
    "print(index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df7d4958-56a3-48ea-bd64-3486fdb57fb7",
   "metadata": {},
   "source": [
    "## Search neighbors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89ba2eaa-4c85-4e1c-b07c-920394e55dce",
   "metadata": {},
   "source": [
    "It is recommended to reuse [device resources](https://docs.rapids.ai/api/raft/nightly/pylibraft_api/common/#pylibraft.common.DeviceResources) across multiple invocations of search, since constructing them can be time-consuming. We will reuse the resources by passing the same handle to each RAFT API call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "46e0421b-9335-47a2-8451-a91f56c2f086",
   "metadata": {},
   "outputs": [],
   "source": [
    "handle = DeviceResources()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6365229-18fd-468f-af30-e24b950cbd6e",
   "metadata": {},
   "source": [
    "After setting [SearchParams](https://docs.rapids.ai/api/raft/nightly/pylibraft_api/neighbors/#pylibraft.neighbors.ivf_flat.SearchParams) we search for `k=10` neighbors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "595454e1-7240-4b43-9a73-963d5670b00c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 171 ms, sys: 52.6 ms, total: 224 ms\n",
      "Wall time: 236 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "n_queries = 10000\n",
    "# n_probes is the number of clusters selected in the first (coarse) search step.\n",
    "# It is the only search hyperparameter.\n",
    "search_params = ivf_flat.SearchParams(n_probes=30)\n",
    "\n",
    "# Search for the 10 nearest neighbors of each query.\n",
    "distances, indices = ivf_flat.search(search_params, index, cp.asarray(queries[:n_queries,:]), k=10, handle=handle)\n",
    "\n",
    "# RAFT calls are asynchronous when the handle argument is provided, so we sync before accessing the results.\n",
    "handle.sync()\n",
    "distances, neighbors = cp.asnumpy(distances), cp.asnumpy(indices)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43d20ca7-7b9e-4046-bb52-640a2744db75",
   "metadata": {},
   "source": [
    "The returned arrays have shape {n_queries x 10] and store the distance values and the indices of the searched vectors. We check how accurate the search is. The accuracy of the search is quantified as `recall`, which is a value between 0 and 1 and tells us what fraction of the returned neighbors are actual k nearest neighbors. "
   ]
  },
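  {
   "cell_type": "markdown",
   "id": "recall-definition-sketch",
   "metadata": {},
   "source": [
    "To make the definition of recall concrete, here is a minimal sketch in plain Python. The helper `recall_at_k` is hypothetical and written only for this illustration; it is not the `calc_recall` function imported from `utils`, whose implementation may differ.\n",
    "\n",
    "```python\n",
    "def recall_at_k(found, ground_truth, k):\n",
    "    # Count how many returned neighbors are among the true k nearest neighbors.\n",
    "    hits = 0\n",
    "    for found_row, true_row in zip(found, ground_truth):\n",
    "        hits += len(set(found_row[:k]) & set(true_row[:k]))\n",
    "    return hits / (len(found) * k)\n",
    "\n",
    "# Made-up example with two queries and k=3: the first query finds\n",
    "# 2 of its 3 true neighbors, the second finds all 3.\n",
    "found = [[0, 5, 9], [4, 2, 7]]\n",
    "truth = [[0, 9, 1], [2, 4, 7]]\n",
    "print(recall_at_k(found, truth, k=3))\n",
    "```\n",
    "\n",
    "The order of the neighbors within a row does not matter for this metric, which is why the rows are compared as sets."
   ]
  },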
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8cd9cd20-ca00-4a35-a0a0-86636521b31a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.97406"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "calc_recall(neighbors, gt_neighbors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cde5079c-9777-45a1-9545-cffbcc59988f",
   "metadata": {},
   "source": [
    "## Save and load the index\n",
    "You can serialize the index to a file using [save](https://docs.rapids.ai/api/raft/nightly/pylibraft_api/neighbors/#pylibraft.neighbors.ivf_flat.save) and [load](https://docs.rapids.ai/api/raft/nightly/pylibraft_api/neighbors/#pylibraft.neighbors.ivf_flat.load) it later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bf94e45c-e7fb-4aa3-a611-ddaee7ac41ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "index_file = os.path.join(WORK_FOLDER, \"my_ivf_flat_index.bin\")\n",
    "ivf_flat.save(index_file, index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "1622d9be-be41-4d25-be99-d348c5e54957",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = ivf_flat.load(index_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15d503e5-05e8-47ce-8501-e13fc512099c",
   "metadata": {},
   "source": [
    "## Tune search parameters\n",
    "Search has a single hyperparameter, `n_probes`, which sets how many clusters are searched (probed) for each query. Within a probed cluster, the distance is computed between the query point and every vector in the cluster, and the top-k neighbors are selected. Finally, the top-k neighbors are selected from the candidates of all probed clusters.\n",
    "\n",
    "Let's see how search accuracy and latency change when we vary the `n_probes` parameter."
   ]
  },
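  {
   "cell_type": "markdown",
   "id": "toy-ivf-search-sketch",
   "metadata": {},
   "source": [
    "The search procedure described above can be sketched on the CPU with a few lines of plain Python. This is only a toy illustration under stated assumptions, not how RAFT implements the search: the function `toy_ivf_search`, the hand-made cluster centers and the tiny 2-D dataset are all hypothetical. For brevity the sketch merges the per-cluster top-k step into a single final selection, which returns the same neighbors.\n",
    "\n",
    "```python\n",
    "import heapq\n",
    "\n",
    "def sq_dist(a, b):\n",
    "    # Squared Euclidean distance between two vectors.\n",
    "    return sum((x - y) ** 2 for x, y in zip(a, b))\n",
    "\n",
    "def toy_ivf_search(centers, lists, query, n_probes, k):\n",
    "    # Coarse step: pick the n_probes clusters whose centers are closest to the query.\n",
    "    probed = heapq.nsmallest(n_probes, range(len(centers)),\n",
    "                             key=lambda c: sq_dist(centers[c], query))\n",
    "    # Fine step: compute exact distances to every vector in the probed clusters.\n",
    "    candidates = [(sq_dist(vec, query), idx) for c in probed for idx, vec in lists[c]]\n",
    "    # Final step: keep the k closest candidates overall.\n",
    "    return heapq.nsmallest(k, candidates)\n",
    "\n",
    "# Three hand-made clusters of 2-D points, stored as (id, vector) pairs.\n",
    "centers = [(0.0, 0.0), (10.0, 0.0), (0.0, 10.0)]\n",
    "lists = [\n",
    "    [(0, (0.0, 1.0)), (1, (1.0, 0.0)), (2, (4.9, 0.0))],\n",
    "    [(3, (5.1, 0.0)), (4, (10.0, 1.0))],\n",
    "    [(5, (0.0, 9.0)), (6, (1.0, 10.0))],\n",
    "]\n",
    "query = (5.2, 0.0)\n",
    "\n",
    "# Probing one cluster misses vector 2, which lies just across the cluster boundary.\n",
    "print(toy_ivf_search(centers, lists, query, n_probes=1, k=2))\n",
    "# Probing two clusters recovers it, at the cost of more distance computations.\n",
    "print(toy_ivf_search(centers, lists, query, n_probes=2, k=2))\n",
    "```\n",
    "\n",
    "The example shows why recall grows with `n_probes`: a true neighbor can sit in a cluster whose center is not the closest one to the query."
   ]
  },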
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ace0c31f-af75-4352-a438-123a9a03612c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Benchmarking search with n_probes = 10\n",
      "recall 0.86625\n",
      "Average search time:   0.026 +/- 0.000259 s\n",
      "Queries per second (QPS):   384968\n",
      "\n",
      "Benchmarking search with n_probes = 20\n",
      "recall 0.94705\n",
      "Average search time:   0.050 +/- 5.43e-05 s\n",
      "Queries per second (QPS):   198880\n",
      "\n",
      "Benchmarking search with n_probes = 30\n",
      "recall 0.97406\n",
      "Average search time:   0.075 +/- 8.59e-05 s\n",
      "Queries per second (QPS):   133954\n",
      "\n",
      "Benchmarking search with n_probes = 50\n",
      "recall 0.99169\n",
      "Average search time:   0.123 +/- 4.78e-05 s\n",
      "Queries per second (QPS):    80997\n",
      "\n",
      "Benchmarking search with n_probes = 100\n",
      "recall 0.99844\n",
      "Average search time:   0.244 +/- 0.000249 s\n",
      "Queries per second (QPS):    40934\n",
      "\n",
      "Benchmarking search with n_probes = 200\n",
      "recall 0.99932\n",
      "Average search time:   0.468 +/- 0.000367 s\n",
      "Queries per second (QPS):    21382\n",
      "\n",
      "Benchmarking search with n_probes = 500\n",
      "recall 0.99933\n",
      "Average search time:   1.039 +/- 0.000209 s\n",
      "Queries per second (QPS):     9625\n",
      "\n",
      "Benchmarking search with n_probes = 1024\n",
      "recall 0.99935\n",
      "Average search time:   0.701 +/- 0.00579 s\n",
      "Queries per second (QPS):    14273\n"
     ]
    }
   ],
   "source": [
    "n_probes = np.asarray([10, 20, 30, 50, 100, 200, 500, 1024])\n",
    "qps = np.zeros(n_probes.shape)\n",
    "recall = np.zeros(n_probes.shape)\n",
    "\n",
    "for i in range(len(n_probes)):\n",
    "    print(\"\\nBenchmarking search with n_probes =\", n_probes[i])\n",
    "    timer = BenchmarkTimer(reps=1, warmup=1)\n",
    "    for rep in timer.benchmark_runs():\n",
    "        distances, neighbors = ivf_flat.search(\n",
    "            ivf_flat.SearchParams(n_probes=n_probes[i]),\n",
    "            index,\n",
    "            cp.asarray(queries),\n",
    "            k=10,\n",
    "            handle=handle,\n",
    "        )\n",
    "        handle.sync()\n",
    "    \n",
    "    recall[i] = calc_recall(cp.asnumpy(neighbors), gt_neighbors)\n",
    "    print(\"recall\", recall[i])\n",
    "\n",
    "    timings = np.asarray(timer.timings)\n",
    "    avg_time = timings.mean()\n",
    "    std_time = timings.std()\n",
    "    qps[i] = queries.shape[0] / avg_time\n",
    "    print(\"Average search time: {0:7.3f} +/- {1:7.3} s\".format(avg_time, std_time))\n",
    "    print(\"Queries per second (QPS): {0:8.0f}\".format(qps[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20b2498c-7231-4211-990e-600d5c26a9a1",
   "metadata": {},
   "source": [
    "The plots below illustrate how the accuracy (recall) and the throughput (queries per second) depend on the `n_probes` parameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ac370f-91c8-4054-95c7-a749df5f16d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(12, 3))\n",
    "\n",
    "# Recall as a function of n_probes\n",
    "ax = fig.add_subplot(131)\n",
    "ax.plot(n_probes, recall, 'o-')\n",
    "ax.set_xlabel('n_probes')\n",
    "ax.grid()\n",
    "ax.set_ylabel('recall (@k=10)')\n",
    "\n",
    "# Throughput as a function of n_probes\n",
    "ax = fig.add_subplot(132)\n",
    "ax.plot(n_probes, qps, 'o-')\n",
    "ax.set_xlabel('n_probes')\n",
    "ax.grid()\n",
    "ax.set_ylabel('queries per second')\n",
    "\n",
    "# Throughput versus recall trade-off\n",
    "ax = fig.add_subplot(133)\n",
    "ax.plot(recall, qps, 'o-')\n",
    "ax.set_xlabel('recall')\n",
    "ax.grid()\n",
    "# ax.set_yscale('log')  # optional: logarithmic QPS axis\n",
    "ax.set_ylabel('queries per second');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81e7ad6a-bddc-45de-9cce-0fb913f91efe",
   "metadata": {},
   "source": [
    "## Adjust build parameters\n",
    "### n_lists\n",
    "The number of clusters (or lists) is set by the `n_lists` parameter. Let's change it to 100 clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "addbfff3-7773-4290-9608-5489edf4886d",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "build_params = ivf_flat.IndexParams(\n",
    "        n_lists=100,\n",
    "        metric=\"euclidean\",\n",
    "        kmeans_trainset_fraction=1,\n",
    "        kmeans_n_iters=20,\n",
    "        add_data_on_build=True\n",
    "    )\n",
    "\n",
    "index = ivf_flat.build(build_params, dataset, handle=handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48db27f9-54c8-4dac-839b-af94ada8885f",
   "metadata": {},
   "source": [
    "The ratio `n_probes / n_lists` determines what fraction of the dataset is searched for each query. The right combination depends on the use case. Here we search 10 of the 100 clusters for each query."
   ]
  },
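  {
   "cell_type": "markdown",
   "id": "probe-ratio-arithmetic",
   "metadata": {},
   "source": [
    "As a rough back-of-the-envelope check of this ratio, the arithmetic below estimates how many vectors are compared against each query. It assumes that all clusters have the same size, which is a simplification: real k-means clusters are usually unbalanced, so the actual number varies from query to query. The dataset size is the one reported when the dataset was loaded.\n",
    "\n",
    "```python\n",
    "n_vectors = 1_000_000  # number of vectors in the dataset\n",
    "n_lists = 100          # clusters in the index\n",
    "n_probes = 10          # clusters searched per query\n",
    "\n",
    "# Assumption: equally sized clusters, so each probe covers n_vectors / n_lists vectors.\n",
    "fraction_searched = n_probes / n_lists\n",
    "vectors_compared = n_vectors * n_probes // n_lists\n",
    "\n",
    "print(f'fraction of the dataset searched: {fraction_searched:.0%}')\n",
    "print(f'approximate distance computations per query: {vectors_compared}')\n",
    "```\n",
    "\n",
    "With the first index in this notebook (`n_lists=1024`), the same `n_probes=10` would cover roughly 1% of the dataset, which is why the two configurations are not directly comparable."
   ]
  },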
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a0149ad-de38-4195-97a5-ce5d5d877036",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "n_queries=10000\n",
    "\n",
    "search_params = ivf_flat.SearchParams(n_probes=10)\n",
    "\n",
    "# Search 10 nearest neighbors.\n",
    "distances, indices = ivf_flat.search(search_params, index, cp.asarray(queries[:n_queries,:]), k=10, handle=handle)\n",
    "    \n",
    "handle.sync()\n",
    "distances, neighbors = cp.asnumpy(distances), cp.asnumpy(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eedc3ec4-06af-42c5-8cdf-490a5c2bc49a",
   "metadata": {},
   "outputs": [],
   "source": [
    "calc_recall(neighbors, gt_neighbors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c44800f-1e9e-4f7b-87fe-0f25e6590faa",
   "metadata": {},
   "source": [
    "### kmeans_trainset_fraction\n",
    "During clustering we can sub-sample the dataset. The parameter `kmeans_trainset_fraction` determines what fraction of it is used for training. Often we get good results by using only 1/10th of the dataset for clustering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a54d190-64d4-4cd4-a497-365cbffda871",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "build_params = ivf_flat.IndexParams( \n",
    "        n_lists=100, \n",
    "        metric=\"sqeuclidean\", \n",
    "        kmeans_trainset_fraction=0.1, \n",
    "        kmeans_n_iters=20 \n",
    "    ) \n",
    "index = ivf_flat.build(build_params, dataset, handle=handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d86a213-d6ae-4fca-9082-cb5a4d1dab36",
   "metadata": {},
   "source": [
    "We see only a minimal change in recall compared with the index trained on the full dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc992e8-a5e5-4508-b790-0e934160b660",
   "metadata": {},
   "outputs": [],
   "source": [
    "search_params = ivf_flat.SearchParams(n_probes=10)\n",
    "\n",
    "distances, indices = ivf_flat.search(search_params, index, cp.asarray(queries[:n_queries,:]), k=10, handle=handle)\n",
    "    \n",
    "handle.sync()\n",
    "distances, neighbors = cp.asnumpy(distances), cp.asnumpy(indices)\n",
    "calc_recall(neighbors, gt_neighbors)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25289ebc-7d89-4fa6-bc62-e25b6e77750c",
   "metadata": {},
   "source": [
    "### Add vectors on build\n",
    "Currently you cannot configure how RAFT sub-samples the input. If you want finer control over how the training set is selected, create the index in two steps:\n",
    "1. Define cluster centers on a training set, but do not add any vectors to the index.\n",
    "2. Add vectors to the index (extend).\n",
    "\n",
    "This workflow should be familiar to FAISS users. Note that RAFT does not require adding the data in batches; internal batching is used when necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ebcf970-94ed-4825-9885-277bd984b90c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sub-sample the dataset to form the training set\n",
    "n_train = 10000\n",
    "train_set = dataset[cp.random.choice(dataset.shape[0], n_train, replace=False),:]\n",
    "\n",
    "# Step 1: compute the cluster centers on the training set without adding any vectors\n",
    "build_params = ivf_flat.IndexParams(\n",
    "        n_lists=1024,\n",
    "        metric=\"sqeuclidean\",\n",
    "        kmeans_trainset_fraction=1,\n",
    "        kmeans_n_iters=20,\n",
    "        add_data_on_build=False\n",
    "    )\n",
    "index = ivf_flat.build(build_params, train_set)\n",
    "\n",
    "print(\"Index before adding vectors\", index)\n",
    "\n",
    "# Step 2: add the whole dataset, using the row numbers as vector indices\n",
    "ivf_flat.extend(index, dataset, cp.arange(dataset.shape[0], dtype=cp.int64))\n",
    "\n",
    "print(\"Index after adding vectors\", index)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
